# coding:utf-8
'''
author:wangyi
'''
import pickle
import numpy as np
from tensorflow.contrib import keras as kr

'''
1. Generate training batches (x_batch, y_batch) for the model
2. Training loop: feed the data to the model, minimize the loss
3. Evaluation: every N steps report accuracy gain / loss drop / F1
4. Save / restore the model
'''

def _pad_post(sequences, maxlen):
    """Pad/truncate integer id sequences into a (len(sequences), maxlen) int32 array.

    Mirrors keras ``pad_sequences(padding='post', truncating='post', value=0)``:
    ids beyond ``maxlen`` are dropped, shorter rows are zero-filled on the right.
    (Replaces ``tensorflow.contrib.keras``, which was removed in TF 2.x.)
    """
    out = np.zeros((len(sequences), maxlen), dtype=np.int32)
    for row, seq in enumerate(sequences):
        clipped = seq[:maxlen]
        out[row, :len(clipped)] = clipped
    return out


def generate_data_batch(data_path, word2id_path, batch_size, maxlen=200):
    """Yield (x_batch, y_batch) training batches from a labeled text file.

    Parameters
    ----------
    data_path : str
        UTF-8 text file, one example per line: ``<label>\\t<space-separated words>``.
    word2id_path : str
        Pickle file holding the tuple ``(label2id, word2id)``.
        NOTE(review): pickle is unsafe on untrusted input — only load trusted files.
    batch_size : int
        Examples per batch; the final batch may be smaller.
    maxlen : int
        Fixed sequence length; longer sequences are truncated (post),
        shorter ones zero-padded (post).

    Yields
    ------
    x_batch : np.ndarray, shape (n, maxlen), int32 word ids
    y_batch : list of np.ndarray one-hot label vectors
    """
    # e.g. label2id = {'sports': 0, 'entertainment': 1}
    with open(word2id_path, 'rb') as f:  # 'with' fixes the leaked file handle
        label2id, word2id = pickle.load(f)
    num_labels = len(label2id)
    with open(data_path, encoding='utf-8') as f:
        lines = f.readlines()
    # Ceil division: the original's `len//batch_size + 1` yielded an extra
    # EMPTY batch whenever len(lines) was an exact multiple of batch_size.
    batch_nums = (len(lines) + batch_size - 1) // batch_size
    for i in range(batch_nums):
        batch_lines = lines[i * batch_size:(i + 1) * batch_size]
        x_batch, y_batch = [], []
        for line in batch_lines:
            line = line.strip()
            if not line:
                continue  # skip blank lines instead of crashing on split('\t')
            label, content = line.split('\t')
            # One-hot label vector, e.g. [1, 0, ..., 0]
            y = np.zeros(shape=[num_labels])
            y[label2id[label]] = 1
            y_batch.append(y)
            # Map words to ids; words missing from the vocabulary are dropped.
            x = [word2id[word] for word in content.split(' ') if word in word2id]
            x_batch.append(x)
        # BUG FIX: pad once per batch, AFTER the inner loop. The original
        # called pad_sequences inside the loop, turning x_batch into an
        # ndarray whose .append raised AttributeError on the second line
        # of every batch.
        x_batch = _pad_post(x_batch, maxlen)
        yield x_batch, y_batch




